package drds.plus.sql_process.optimizer.execute_plan_optimizer;

import drds.plus.common.jdbc.Parameters;
import drds.plus.common.properties.ConnectionProperties;
import drds.plus.sql_process.abstract_syntax_tree.ObjectCreateFactory;
import drds.plus.sql_process.abstract_syntax_tree.execute_plan.ExecutePlan;
import drds.plus.sql_process.abstract_syntax_tree.execute_plan.query.Join;
import drds.plus.sql_process.abstract_syntax_tree.execute_plan.query.MergeQuery;
import drds.plus.sql_process.abstract_syntax_tree.execute_plan.query.Query;
import drds.plus.sql_process.abstract_syntax_tree.execute_plan.query.QueryWithIndex;
import drds.plus.util.ExtraCmd;

import java.util.List;
import java.util.Map;

/**
 * 如果设置了MergeConcurrent 并且值为True，则将所有的Merge变为并行
 *
 * <pre>
 * TODO: 需要考虑 sort join query展开为 sort join merge后，对应的join的返回列，以及对应的filter需要重新build，应该基于语法树来做
 * </pre>
 */
public class MergeJoinExpandOptimizer implements ExecutePlanOptimizer {

    /**
     * 左右表是否为单库上的多表 join 单库上的单表/多表
     */
    private static boolean isNeedExpand(MergeQuery mergeQuery, Query query, Map<String, Object> extraCmd) {
        boolean expand = true;
        for (ExecutePlan executePlan : mergeQuery.getExecutePlanList()) {
            expand &= executePlan.getDataNodeId().equals(query.getDataNodeId());
            if (!expand) {
                return ExtraCmd.getExtraCmdBoolean(extraCmd, ConnectionProperties.MERGE_EXPAND, false);
            }
        }

        return ExtraCmd.getExtraCmdBoolean(extraCmd, ConnectionProperties.MERGE_EXPAND, true);
    }

    /**
     * 如果设置了MergeConcurrent 并且值为True，则将所有的Merge变为并行
     */

    public ExecutePlan optimize(ExecutePlan executePlan, Parameters parameters, Map<String, Object> extraCmd) {
        return this.findEveryJoin(executePlan, true, true, extraCmd);
    }

    private ExecutePlan findEveryJoin(ExecutePlan executePlan, boolean isExpandLeft, boolean isExpandRight, Map<String, Object> extraCmd) {
        if (executePlan instanceof MergeQuery) {
            List<ExecutePlan> subExecutePlanList = ((MergeQuery) executePlan).getExecutePlanList();
            for (int i = 0; i < subExecutePlanList.size(); i++) {
                subExecutePlanList.set(i, this.findEveryJoin(subExecutePlanList.get(i), isExpandLeft, isExpandRight, extraCmd));
            }

            ((MergeQuery) executePlan).setExecutePlanList(subExecutePlanList);
            return executePlan;
        } else if (executePlan instanceof QueryWithIndex) {
            return executePlan;
        } else if (executePlan instanceof Join) {
            ((Join) executePlan).setLeftNode((Query) this.findEveryJoin(((Join) executePlan).getLeftNode(), isExpandLeft, isExpandRight, extraCmd));
            ((Join) executePlan).setRightNode((Query) this.findEveryJoin(((Join) executePlan).getRightNode(), isExpandLeft, isExpandRight, extraCmd));
            return this.processJoin((Join) executePlan, isExpandLeft, isExpandRight, extraCmd);
        }

        return executePlan;
    }

    private Query processJoin(Join join, boolean isExpandLeft, boolean isExpandRight, Map<String, Object> extraCmd) {
        // 如果一个节点包含limit，group by，order by等条件，则不能展开
        if (!canExpand(join)) {
            // join节点可能自己存在limit
            isExpandLeft = false;
            isExpandRight = false;
        } else if (!canExpand(join.getLeftNode())) {
            isExpandLeft = false;
        } else if (!canExpand(join.getRightNode())) {
            isExpandRight = false;
        }

        if (isExpandLeft && isExpandRight) {
            return this.cartesianProduct(join, extraCmd);
        } else if (isExpandLeft) {
            return this.expandLeft(join, extraCmd);
        } else if (isExpandRight) {
            return this.expandRight(join, extraCmd);
        } else {
            return join;
        }
    }

    private boolean canExpand(Query query) {
        // 如果一个节点包含limit，group by，order by等条件
        return query.getLimitFrom() == null && query.getLimitTo() == null && !query.isExistAggregate();
    }

    /**
     * 将左边的merge展开，依次和右边做join
     */
    public Query expandLeft(Join join, Map<String, Object> extraCmd) {
        if (!(join.getLeftNode() instanceof MergeQuery)) {
            return join;
        }

        MergeQuery mergeQuery = (MergeQuery) join.getLeftNode();
        if (!isNeedExpand(mergeQuery, join.getRightNode(), extraCmd)) {
            return join;
        }

        MergeQuery newMergeQuery = ObjectCreateFactory.createMergeQuery();
        for (ExecutePlan executePlan : mergeQuery.getExecutePlanList()) {
            Join newJoin = (Join) join.copy();
            newJoin.setLeftNode((Query) executePlan);
            newJoin.setRightNode(join.getRightNode());
            newJoin.setDataNodeId(join.getDataNodeId());
            newMergeQuery.addExecutePlan(newJoin);
        }

        newMergeQuery.setAlias(join.getAlias());
        newMergeQuery.setSelectItemList(join.getItemList());
        newMergeQuery.setConsistent(join.isConsistent());
        newMergeQuery.setGroupByList(join.getGroupByList());
        newMergeQuery.setLimitFrom(join.getLimitFrom());
        newMergeQuery.setLimitTo(join.getLimitTo());
        newMergeQuery.setOrderByList(join.getOrderByList());
        newMergeQuery.setQueryConcurrency(join.getQueryConcurrencyWay());
        newMergeQuery.having(join.getHaving());
        newMergeQuery.setValueFilter(join.getValueFilter());
        newMergeQuery.setOtherJoinOnFilter(join.getOtherJoinOnFilter());
        newMergeQuery.setDataNodeId(join.getDataNodeId());
        newMergeQuery.setExistAggregate(join.isExistAggregate());
        newMergeQuery.setIsSubQuery(join.isSubQuery());
        return newMergeQuery;
    }

    /**
     * 将右边的merge展开，依次和左边做join
     */
    public Query expandRight(Join join, Map<String, Object> extraCmd) {
        if (!(join.getRightNode() instanceof MergeQuery)) {
            return join;
        }

        MergeQuery mergeQuery = (MergeQuery) join.getRightNode();
        if (!isNeedExpand(mergeQuery, join.getLeftNode(), extraCmd)) {
            return join;
        }

        MergeQuery newMergeQuery = ObjectCreateFactory.createMergeQuery();
        for (ExecutePlan executePlan : mergeQuery.getExecutePlanList()) {
            Join newJoin = (Join) join.copy();
            newJoin.setLeftNode(join.getLeftNode());
            ((Query) executePlan).setAlias(mergeQuery.getAlias());
            newJoin.setRightNode((Query) executePlan);
            newJoin.setDataNodeId(join.getDataNodeId());
            newMergeQuery.addExecutePlan(newJoin);
        }

        newMergeQuery.setAlias(join.getAlias());
        newMergeQuery.setSelectItemList(join.getItemList());
        newMergeQuery.setConsistent(join.isConsistent());
        newMergeQuery.setGroupByList(join.getGroupByList());
        newMergeQuery.having(join.getHaving());
        newMergeQuery.setLimitFrom(join.getLimitFrom());
        newMergeQuery.setLimitTo(join.getLimitTo());
        newMergeQuery.setOrderByList(join.getOrderByList());
        newMergeQuery.setQueryConcurrency(join.getQueryConcurrencyWay());
        newMergeQuery.setValueFilter(join.getValueFilter());
        newMergeQuery.setDataNodeId(join.getDataNodeId());
        newMergeQuery.setOtherJoinOnFilter(join.getOtherJoinOnFilter());
        newMergeQuery.setExistAggregate(join.isExistAggregate());
        newMergeQuery.setIsSubQuery(join.isSubQuery());
        return newMergeQuery;
    }

    /**
     * 左右都展开做笛卡尔积
     */
    public Query cartesianProduct(Join join, Map<String, Object> extraCmd) {
        if (join.getLeftNode() instanceof MergeQuery && !(join.getRightNode() instanceof MergeQuery)) {
            return this.expandLeft(join, extraCmd);
        }

        if (!(join.getLeftNode() instanceof MergeQuery) && (join.getRightNode() instanceof MergeQuery)) {
            return this.expandRight(join, extraCmd);
        }

        if (!(join.getLeftNode() instanceof MergeQuery) && !(join.getRightNode() instanceof MergeQuery)) {
            return join;
        }

        if (!ExtraCmd.getExtraCmdBoolean(extraCmd, ConnectionProperties.MERGE_EXPAND, false)) {
            return join;
        }

        MergeQuery leftMergeQuery = (MergeQuery) join.getLeftNode();
        MergeQuery rightMergeQuery = (MergeQuery) join.getRightNode();
        MergeQuery newMergeQuery = ObjectCreateFactory.createMergeQuery();

        for (ExecutePlan leftChild : leftMergeQuery.getExecutePlanList()) {
            for (ExecutePlan rightChild : rightMergeQuery.getExecutePlanList()) {
                Join newJoin = (Join) join.copy();
                newJoin.setLeftNode((Query) leftChild);
                newJoin.setRightNode((Query) rightChild);
                newJoin.setDataNodeId(leftChild.getDataNodeId());
                newMergeQuery.addExecutePlan(newJoin);
            }
        }
        newMergeQuery.setAlias(join.getAlias());
        newMergeQuery.setSelectItemList(join.getItemList());
        newMergeQuery.setConsistent(join.isConsistent());
        newMergeQuery.setGroupByList(join.getGroupByList());
        newMergeQuery.having(join.getHaving());
        newMergeQuery.setLimitFrom(join.getLimitFrom());
        newMergeQuery.setLimitTo(join.getLimitTo());
        newMergeQuery.setOrderByList(join.getOrderByList());
        newMergeQuery.setQueryConcurrency(join.getQueryConcurrencyWay());
        newMergeQuery.setValueFilter(join.getValueFilter());
        newMergeQuery.setDataNodeId(join.getDataNodeId());
        newMergeQuery.setOtherJoinOnFilter(join.getOtherJoinOnFilter());
        newMergeQuery.setExistAggregate(join.isExistAggregate());
        newMergeQuery.setIsSubQuery(join.isSubQuery());
        return newMergeQuery;
    }

}
